#!/usr/bin/env python3

import numpy as np
import numpy_extra as npe
import rospy
import threading

from ipm import IPM

from cv_bridge import *
from transformations import euler_from_quaternion, quaternion_from_euler
from sensor_msgs.msg import *
from std_msgs.msg import *
from nav_msgs.msg import *
from geometry_msgs.msg import *

# `sys` was never imported explicitly; it only resolved if one of the wildcard
# imports above happened to re-export it. Import it here so the flag below
# does not depend on that accident.
import sys

# Pass --debug on the command line to display the local map in an OpenCV window.
DEBUG = "--debug" in sys.argv

if DEBUG:
    import cv2  # only needed for imshow/waitKey in debug mode

class LocalMapper(object):
    """ROS node that projects semantic camera labels onto a local top-down grid.

    Each image received on ``<ns_camera>/semantic`` is cropped, mapped onto the
    ground plane supplied by ``IPM`` and accumulated per category into
    ``local_map``. The per-cell argmax is published as an ``OccupancyGrid`` on
    ``semantic_grid`` and as a colour image on ``semantic_image``.
    """

    # Side length of the square local map, in cells.
    MAP_SIZE_PX = 1024

    # Vote added to a cell for every projected pixel, indexed by category id.
    # NOTE(review): only three categories are covered; a label >= 3 raises
    # IndexError in on_semantic -- confirm against the category list in use.
    SEMANTIC_WEIGHTS = (1, 100, 100)

    # Fractions of the image height delimiting the rows that get projected.
    CROP_TOP = 0.4
    CROP_BOTTOM = 0.9

    def __init__(
        self,
        node_name="local_mapper_node",
        ns_camera="/camera_front",
        ns_imu="/imu",
        ns_vehicle="/vehicle",
        pixels_per_meter=8.0,
        configuration="camera_ethernet_imx291",
    ):
        """Initialise the node, wait for the category list and start listening.

        Blocks until the ``<ns_camera>/semantic_categories`` parameter exists.

        Args:
            node_name: Name passed to ``rospy.init_node``.
            ns_camera: Namespace providing the ``semantic`` topic and the
                ``semantic_categories`` parameter.
            ns_imu: Namespace providing the ``data`` (Imu) topic.
            ns_vehicle: Namespace providing the ``speed`` (Float32) topic.
            pixels_per_meter: Map resolution in cells per metre.
            configuration: Configuration name forwarded to ``IPM``.
        """
        self.is_initialized = False
        rospy.init_node(node_name)

        # parameters

        self.pixels_per_meter = pixels_per_meter

        # Created before anything that can trigger a callback.
        self.lock = threading.Lock()
        self.ground_plane_r, self.ground_plane_theta = None, None

        # publishers

        self.pub_grid = rospy.Publisher("semantic_grid", OccupancyGrid, queue_size=1)
        self.pub_image = rospy.Publisher("semantic_image", Image, queue_size=1)

        categories_param = "%s/semantic_categories" % ns_camera
        while not rospy.has_param(categories_param):
            rospy.loginfo_throttle(10, "waiting for categories")
            rospy.sleep(1)

        self.categories = rospy.get_param(categories_param)

        # One colour per category with the channel order reversed (presumably
        # RGB -> BGR to match the "bgr8" encoding used when publishing --
        # confirm); the last entry is a visualization "category".
        self.color_map = np.array(
            [list(reversed(c["color"])) for c in self.categories] + [[255, 255, 255]],
            dtype=np.uint8,
        )

        size = self.MAP_SIZE_PX
        self.local_map = np.zeros((size, size, len(self.categories)))

        # argmaxed version of above
        self.local_map_flat = np.zeros((size, size), dtype=np.uint8)

        self.map_center_x = 0.0
        self.map_center_y = 0.0

        self.cur_y = 0.0
        self.cur_x = -self.local_map.shape[0] / self.pixels_per_meter * 0.5
        self.cur_yaw = 0.0
        self.velocity_x = 0.0

        self.ipm = IPM(pixels_per_meter=self.pixels_per_meter, configuration=configuration)
        self.is_initialized = True

        # subscribers
        #
        # Registered last: callbacks may fire as soon as a subscriber exists,
        # and on_semantic needs the lock, the IPM and the map buffers above.

        self.sub_semantic = rospy.Subscriber(
            "%s/semantic" % ns_camera, Image, self.on_semantic)
        self.sub_imu_data = rospy.Subscriber(
            "%s/data" % ns_imu, Imu, self.on_imu_data)
        self.sub_speed = rospy.Subscriber(
            "%s/speed" % ns_vehicle, Float32, self.on_speed)

    def on_speed(self, msg):
        """Store the latest vehicle speed, converted from km/h to m/s."""
        self.velocity_x = msg.data / 3.6  # km/h to m/s

    def on_imu_data(self, msg):
        """Receive IMU data; currently ignored."""
        pass

    def _project_to_cells(self, ground_plane_r, ground_plane_theta):
        """Convert polar ground-plane coordinates into local-map cell indices.

        Args:
            ground_plane_r: Array of ground distances (metres, given how it is
                scaled by ``pixels_per_meter``).
            ground_plane_theta: Array of bearings in radians, same shape.

        Returns:
            ``(rows, cols, where_valid)``: ``where_valid`` is a boolean mask
            with the input shape selecting the points that fall inside the
            map; ``rows`` and ``cols`` are the 1-D integer cell indices of
            exactly those points.
        """
        height, width = self.local_map.shape[:2]
        bearing = ground_plane_theta + self.cur_yaw

        cells_y = (self.cur_y + ground_plane_r * np.sin(bearing)) \
            * self.pixels_per_meter + height / 2.0
        cells_x = (self.cur_x + ground_plane_r * np.cos(bearing)) \
            * self.pixels_per_meter + width / 2.0

        # Bounds are checked on the float values *before* converting to
        # integers: a narrow integer cast first would let far-away points wrap
        # around into the valid range. NaN/inf compare False and are dropped,
        # as are points in (-1, 0) that truncation would fold into cell 0.
        where_valid = (cells_y >= 0) & (cells_x >= 0) & \
            (cells_y < height) & (cells_x < width)

        rows = cells_y[where_valid].astype(np.intp)
        cols = cells_x[where_valid].astype(np.intp)
        return rows, cols, where_valid

    def _publish_grid(self):
        """Publish ``local_map_flat`` as an ``OccupancyGrid`` in ``base_link``."""
        height, width = self.local_map.shape[:2]

        msg_grid = OccupancyGrid()
        msg_grid.header.frame_id = "base_link"
        msg_grid.info.map_load_time = rospy.get_rostime()
        msg_grid.info.resolution = float(1.0 / self.pixels_per_meter)
        msg_grid.info.width = int(width)
        msg_grid.info.height = int(height)
        msg_grid.info.origin.position.x = float(
            self.cur_x + width / self.pixels_per_meter / 2.0)
        msg_grid.info.origin.position.y = float(
            self.cur_y + height / self.pixels_per_meter / 2.0)

        # NOTE(review): the indices below assume quaternion_from_euler returns
        # [w, x, y, z] -- confirm for the `transformations` module in use.
        q = quaternion_from_euler(0.0, 0.0, self.cur_yaw)
        msg_grid.info.origin.orientation.w = q[0]
        msg_grid.info.origin.orientation.x = q[1]
        msg_grid.info.origin.orientation.y = q[2]
        msg_grid.info.origin.orientation.z = q[3]

        # tobytes() is the non-deprecated spelling of tostring() and always
        # emits C-ordered data.
        msg_grid.data = self.local_map_flat.tobytes()
        self.pub_grid.publish(msg_grid)

    def on_semantic(self, msg):
        """Rebuild the local map from one semantic image and publish it.

        The map is cleared on every call, so it only ever reflects the most
        recent image. Grid and image messages are built only when somebody is
        subscribed to them.

        Args:
            msg: ``Image`` whose pixels are category indices (assumed to be a
                single-channel integer image -- TODO confirm).
        """
        # `with` releases the lock even if projection or publishing raises;
        # otherwise a single bad frame would block every later callback.
        with self.lock:
            logits = imgmsg_to_cv2(msg)
            if self.ground_plane_r is None or self.ground_plane_theta is None:
                self.ground_plane_r, self.ground_plane_theta = \
                    self.ipm.get_ground_plane(logits.shape)

            # Only the band of rows between the two cut-offs is projected.
            cutoff_top = int(logits.shape[0] * self.CROP_TOP)
            cutoff_bottom = int(logits.shape[0] * self.CROP_BOTTOM)

            logits = logits[cutoff_top:cutoff_bottom, :]
            rows, cols, where_valid = self._project_to_cells(
                self.ground_plane_r[cutoff_top:cutoff_bottom, :],
                self.ground_plane_theta[cutoff_top:cutoff_bottom, :],
            )
            labels = logits[where_valid]

            self.local_map.fill(0)

            # np.add.at accumulates correctly when several pixels hit the
            # same cell (plain fancy-index += would count them only once).
            semantic_weights = np.array(self.SEMANTIC_WEIGHTS, dtype=np.uint16)
            np.add.at(self.local_map, (rows, cols, labels), semantic_weights[labels])

            self.local_map_flat = np.argmax(self.local_map, axis=2).astype(np.uint8)

            if self.pub_grid.get_num_connections() > 0:
                self._publish_grid()

            publish_image = self.pub_image.get_num_connections() > 0
            if publish_image or DEBUG:
                # Rows are flipped vertically for display.
                rendered = self.color_map[self.local_map_flat[::-1, :]]

                if publish_image:
                    self.pub_image.publish(cv2_to_imgmsg(rendered, encoding="bgr8"))

                if DEBUG:
                    cv2.imshow('local_map', rendered)
                    cv2.waitKey(1)

    def publish(self):
        """Placeholder; currently does nothing."""
        pass

    def shift_map(self, shift_x, shift_y):
        """Move the map contents and the current position by the given offset.

        Args:
            shift_x: Offset along x in metres.
            shift_y: Offset along y in metres.
        """
        with self.lock:
            shift_px = int(shift_x * self.pixels_per_meter)
            shift_py = int(shift_y * self.pixels_per_meter)
            # NOTE(review): the position moves by the exact offset while the
            # map moves by whole cells, and local_map_flat is not shifted --
            # confirm this is intended.
            self.cur_x -= shift_x
            self.cur_y -= shift_y
            self.local_map = npe.shift_2d_replace(self.local_map, -shift_px, -shift_py, 0)

    def spin(self):
        """Keep the node alive, waking at 50 Hz, until ROS shuts down.

        All work happens in the subscriber callbacks; the loop itself performs
        no periodic processing.
        """
        rate = rospy.Rate(50)
        try:
            while not rospy.is_shutdown():
                rate.sleep()
        except rospy.ROSInterruptException:
            # Raised by rate.sleep() when the node is shut down mid-sleep;
            # that is the normal way for this loop to end.
            pass

if __name__ == "__main__":
    # rospy sleeps raise ROSInterruptException when the node is shut down
    # (e.g. Ctrl-C while the constructor is still waiting for parameters);
    # treat that as a normal exit instead of printing a traceback.
    try:
        local_mapper = LocalMapper()
        local_mapper.spin()
    except rospy.ROSInterruptException:
        pass

